from collections import defaultdict
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import sklearn

import pandas as pd

from nltk.tokenize import sent_tokenize
import re

from flair.models import SequenceTagger
from flair.data import Sentence

# Case-insensitive selector for German (plus a few English) protest-related
# word stems. Matching is substring-based, so e.g. "demonstr" also hits
# "Demonstration" and "Großdemonstrationen". (re.UNICODE is already the
# default for str patterns in Python 3; it is kept for explicitness.)
protest_regex = re.compile(
    r'protest|versamm|demonstr|kundgebung|kampagne|'
    r'soziale bewegung|hausbesetz|streik|unterschriftensammlung|'
    r'hasskriminalität|unruhen|aufruhr|aufstand|boykott|riot|'
    r'aktivis|widerstand|mobilisierung|petition|'
    r'bürgerinitiative|bürgerbegehren|aufmarsch',
    re.IGNORECASE | re.UNICODE,
)

def _context_radius(filter_size):
    """Return how many neighbouring sentences to keep on each side of a match."""
    # Only 0, 1 and 2 are distinguished; any larger value keeps +/-1 sentence,
    # which preserves the historical behaviour of this function.
    # NOTE(review): values > 2 probably were meant to widen the window further.
    if filter_size == 2:
        return 2
    return min(filter_size, 1)


def _filter_document(doc, filter_size):
    """Shorten one document to the sentences around protest-keyword matches.

    The first sentence (assumed headline) is always kept when there is at
    least one match. A document with no matching sentence is returned in
    full, and non-string values (e.g. NaN for missing text) are returned
    unchanged because they cannot be sentence-tokenized.
    """
    if not isinstance(doc, str):
        return doc

    sentences = sent_tokenize(doc, language='german')
    hits = [i for i, sentence in enumerate(sentences) if protest_regex.search(sentence)]
    if not hits:
        # keep the entire document if no protest term matches
        return doc

    radius = _context_radius(filter_size)
    keep = {0}  # make sure the headline is included, too
    for i in hits:
        # window [i - radius, i + radius], clipped to the document bounds
        keep.update(range(max(0, i - radius), min(len(sentences), i + radius + 1)))
    return " ".join(sentences[i] for i in sorted(keep))


def reformat_df(data_df, filter_size = -1):
    """Return a copy of ``data_df`` reduced to columns ``text_a`` and ``labels``.

    Args:
        data_df: DataFrame with a ``text`` column and, optionally, a
            ``labels`` column. All other columns are dropped; the input
            frame is not modified.
        filter_size: Sentence-filtering mode. A negative value (default)
            leaves texts untouched. ``0`` keeps only sentences matching
            ``protest_regex``; ``1`` additionally keeps one neighbouring
            sentence on each side; ``2`` keeps two on each side. Values
            above 2 currently behave like ``1``. When a document has any
            match, its first sentence is kept as well; documents with no
            match (or non-string texts) are kept as they are.

    Returns:
        A new DataFrame where ``text`` is renamed to ``text_a``. ``labels``
        is cast to ``int`` when present in the input, otherwise it is a
        column filled with ``None``.
    """
    if 'labels' in data_df.columns:
        df = data_df[['text', 'labels']].copy()
        df["labels"] = df["labels"].astype(int)
    else:
        df = data_df[['text']].copy()
        df["labels"] = None

    if filter_size >= 0:
        # list assignment is positional, so the original row order is kept
        df['text'] = [_filter_document(d, filter_size) for d in tqdm(df['text'])]

    return df.rename(columns={"text": "text_a"})
